{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chain\n",
    "![Chain](images/chain.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic Graph\n",
    "![Graph](images/graph.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### State\n",
    "A State shows the graph’s current setup and tracks changes over time. It serves as input and output for all Nodes and Edges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing_extensions import TypedDict\n",
    "\n",
    "class State(TypedDict):\n",
    "    state: str"
   ]
  },
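  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of what this means at runtime, assuming the standard-library `typing.TypedDict` (which behaves the same way here). The name `ExampleState` is made up for illustration. A `TypedDict` only adds type hints, so a state value is an ordinary dictionary.\n",
    "\n",
    "```python\n",
    "from typing import TypedDict\n",
    "\n",
    "class ExampleState(TypedDict):\n",
    "    state: str\n",
    "\n",
    "# Calling the class builds a plain dict; the class only documents the expected keys.\n",
    "current = ExampleState(state='Hi, there!')\n",
    "print(type(current) is dict, current)\n",
    "```"
   ]
  },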
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Node\n",
    "a Node represents an individual element or point within the graph, storing data and connecting with other nodes through edges to form relationships.\n",
    "\n",
    "a Node is a Python function, the first arg of the function is a State"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def node_1(state):\n",
    "    print(\"Node 1\")\n",
    "    return {\"state\": state[\"state\"] + \"-1-\"}\n",
    "\n",
    "def node_2(state):\n",
    "    print(\"Node 2\")\n",
    "    return {\"state\": state[\"state\"] + \"-2-\"}\n",
    "\n",
    "def node_3(state):\n",
    "    print(\"Node 3\")\n",
    "    return {\"state\": state[\"state\"] + \"-3-\"}"
   ]
  },
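  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A node can be tried out without a graph. The sketch below (the name `example_node` is hypothetical) calls a node function directly with a state dictionary and shows that the return value is only the update for the `state` key.\n",
    "\n",
    "```python\n",
    "def example_node(state):\n",
    "    # Read the current value and return only the key that should change.\n",
    "    return {'state': state['state'] + '-1-'}\n",
    "\n",
    "update = example_node({'state': 'Hi'})\n",
    "print(update)  # {'state': 'Hi-1-'}\n",
    "```"
   ]
  },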
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Edges\n",
    "an Edge represents a connection between two nodes, defining the relationship or interaction between them and potentially carrying additional data.\n",
    "types:\n",
    "    - normal edges: always go this way (from node_1 to node_2)\n",
    "    - conditional edges: optional route between nodes. a Pythin function that returns a next node based on a logic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from typing import Literal\n",
    "\n",
    "def get_random_node(state) -> Literal[\"node_2\", \"node_3\"]:\n",
    "    current_state = state['state'] # usually the desiction is based on current state\n",
    "    if random.random() < 0.5:\n",
    "        return \"node_2\"\n",
    "    return \"node_3\""
   ]
  },
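  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The routing function above picks a node at random. A minimal sketch of a conditional edge that decides from the state instead; the name `route_by_length` and the length rule are made up for illustration.\n",
    "\n",
    "```python\n",
    "from typing import Literal\n",
    "\n",
    "def route_by_length(state) -> Literal['node_2', 'node_3']:\n",
    "    # Route short strings to node_2 and everything else to node_3.\n",
    "    if len(state['state']) < 10:\n",
    "        return 'node_2'\n",
    "    return 'node_3'\n",
    "\n",
    "print(route_by_length({'state': 'Hi'}))              # node_2\n",
    "print(route_by_length({'state': 'Hi, there! -1-'}))  # node_3\n",
    "```\n",
    "\n",
    "A function like this could be passed to `add_conditional_edges` in place of `get_random_node`."
   ]
  },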
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Graph Construction and Invocation\n",
    "START and END are special nodes that represent the start and end of the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "from langgraph.graph import StateGraph, START, END\n",
    "\n",
    "# generate\n",
    "builder = StateGraph(State)\n",
    "builder.add_node(\"node_1\", node_1)\n",
    "builder.add_node(\"node_2\", node_2)\n",
    "builder.add_node(\"node_3\", node_3)\n",
    "\n",
    "# logic\n",
    "builder.add_edge(START, \"node_1\")\n",
    "builder.add_conditional_edges(\"node_1\", get_random_node)\n",
    "builder.add_edge(\"node_2\", END)\n",
    "builder.add_edge(\"node_3\", END)\n",
    "\n",
    "# building\n",
    "graph = builder.compile()\n",
    "\n",
    "# visualize\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.invoke({\"state\" : \"Hi, there!\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Some LLM related concepts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Messages\n",
    "chat models operate on messages.\n",
    "various message types (check out our other Spring AI video):\n",
    "\n",
    "    - HumanMessage\n",
    "    - AIMessage\n",
    "    - SystemMessage\n",
    "    - ToolMessage (will check later)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n",
    "\n",
    "# Initial SystemMessage to set context\n",
    "messages = [\n",
    "    SystemMessage(content=\"You are a witty and humorous AI assistant for developers, specializing in automating their daily coding headaches.\"),\n",
    "]\n",
    "\n",
    "messages.append(AIMessage(content=\"Hey there! I’m your AI assistant. Ready to debug your code—or your life?\", name=\"Model\"))\n",
    "messages.append(HumanMessage(content=\"Can you handle all my TODO comments?\", name=\"Dev\"))\n",
    "messages.append(AIMessage(content=\"Sure! I’ll replace 'TODO: Refactor' with 'TODO: Blame someone else.' Problem solved!\", name=\"Model\"))\n",
    "messages.append(HumanMessage(content=\"Nice. Can you optimize my SQL too?\", name=\"Dev\"))\n",
    "\n",
    "for message in messages:\n",
    "    message.pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Chat models\n",
    "processes messages as prompts and response with completion.\n",
    "- OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# OPENAI_API_KEY environment variable must be set\n",
    "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
    "result = llm.invoke(messages)\n",
    "\n",
    "type(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "print(json.dumps(vars(result), indent=4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tools\n",
    "Tools help an AI use special apps or systems to do things it can’t do on its own, like checking the weather or solving tricky problems.\n",
    "\n",
    "![Graph](images/tools.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multiply_values(a, b):\n",
    "    \"\"\"\n",
    "    Multiply two values and return the result.\n",
    "\n",
    "    Parameters:\n",
    "        a (float): The first value.\n",
    "        b (float): The second value.\n",
    "\n",
    "    Returns:\n",
    "        float: The product of a and b.\n",
    "    \"\"\"\n",
    "    return a * b\n",
    "\n",
    "llm_tools = llm.bind_tools([multiply_values])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### How does LLM know which tool to use?\n",
    "- the name of the function\n",
    "- docstring definition\n",
    "- number of arguments\n",
    "- ..."
   ]
  },
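  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough sketch of where this information comes from, using only the standard library. The helper `describe_function` is hypothetical and is not how `bind_tools` is implemented; it only shows what can be read from a plain Python function.\n",
    "\n",
    "```python\n",
    "import inspect\n",
    "\n",
    "def multiply_values(a: float, b: float) -> float:\n",
    "    '''Multiply two values and return the result.'''\n",
    "    return a * b\n",
    "\n",
    "def describe_function(func):\n",
    "    # Collect the name, docstring and parameter types of a function.\n",
    "    signature = inspect.signature(func)\n",
    "    return {\n",
    "        'name': func.__name__,\n",
    "        'description': inspect.getdoc(func),\n",
    "        'parameters': {\n",
    "            name: param.annotation.__name__\n",
    "            for name, param in signature.parameters.items()\n",
    "        },\n",
    "    }\n",
    "\n",
    "print(describe_function(multiply_values))\n",
    "```"
   ]
  },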
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_call = llm_tools.invoke([HumanMessage(content=f\"What is 2 multiplied by 3\", name=\"Dev\")])\n",
    "\n",
    "print(json.dumps(vars(tool_call), indent=4))"
   ]
  },
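  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model only returns a request to call the tool; running it is up to the application. A minimal sketch of that step, assuming a tool call is a dictionary with `name` and `args` keys, as found in the `tool_calls` list of an `AIMessage`. The `example_call` value and the `available_tools` registry are hand-written for illustration and are not real model output.\n",
    "\n",
    "```python\n",
    "def multiply_values(a: float, b: float) -> float:\n",
    "    return a * b\n",
    "\n",
    "# Hypothetical registry mapping tool names to Python functions.\n",
    "available_tools = {'multiply_values': multiply_values}\n",
    "\n",
    "# Hand-written stand-in for one entry of AIMessage.tool_calls.\n",
    "example_call = {'name': 'multiply_values', 'args': {'a': 2, 'b': 3}, 'id': 'call_1'}\n",
    "\n",
    "def run_tool_call(call):\n",
    "    # Look up the requested function and call it with the given arguments.\n",
    "    func = available_tools[call['name']]\n",
    "    return func(**call['args'])\n",
    "\n",
    "print(run_tool_call(example_call))  # 6\n",
    "```"
   ]
  },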
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Merging Messages with State\n",
    "with LLM we're ineterested in passing messages between nodes. So they become a part of the state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing_extensions import TypedDict\n",
    "from langchain_core.messages import AnyMessage\n",
    "\n",
    "class MessagesState(TypedDict):\n",
    "    messages: list[AnyMessage]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "the problem with this approach - override of the state.\n",
    "so we need to append messages to the list\n",
    "we will use reducers for changing the way how state is being updated.\n",
    "\n",
    "```python\n",
    "def node_1(state):\n",
    "    print(\"Node 1\")\n",
    "    return {\"state\": state[\"state\"] + \"-1-\"}"
   ]
  },
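  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between overriding and reducing can be sketched without LangGraph. The helper `apply_update` is hypothetical and much simpler than what the library does; it only illustrates the idea that a reducer combines the old and the new value of a key.\n",
    "\n",
    "```python\n",
    "import operator\n",
    "\n",
    "def apply_update(state, update, reducers):\n",
    "    # Merge an update into the state: keys with a reducer are combined,\n",
    "    # all other keys are overridden.\n",
    "    new_state = dict(state)\n",
    "    for key, value in update.items():\n",
    "        if key in reducers and key in state:\n",
    "            new_state[key] = reducers[key](state[key], value)\n",
    "        else:\n",
    "            new_state[key] = value\n",
    "    return new_state\n",
    "\n",
    "state = {'messages': ['Hi']}\n",
    "update = {'messages': ['Hello!']}\n",
    "\n",
    "print(apply_update(state, update, reducers={}))                          # {'messages': ['Hello!']}\n",
    "print(apply_update(state, update, reducers={'messages': operator.add}))  # {'messages': ['Hi', 'Hello!']}\n",
    "```\n",
    "\n",
    "In this sketch `operator.add` stands in for a reducer such as `add_messages`, which appends to the message list and adds message-specific handling on top."
   ]
  },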
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated\n",
    "from langgraph.graph.message import add_messages\n",
    "\n",
    "class MessagesState(TypedDict):\n",
    "    messages: Annotated[list[AnyMessage], add_messages]\n",
    "\n",
    "\n",
    "initial_messages = [SystemMessage(content=\"You are a witty and humorous AI assistant for developers, specializing in automating their daily coding headaches.\")]\n",
    "new_message = AIMessage(content=\"Hey there! I’m your AI assistant. Ready to debug your code—or your life?\", name=\"Model\")\n",
    "add_messages(initial_messages , new_message)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import MessagesState\n",
    "\n",
    "class MessagesState(MessagesState):\n",
    "    # Extend to include additional keys beyond the pre-built messages key\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combine all together\n",
    "![Graph](images/tools.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "from langgraph.graph import StateGraph, START, END\n",
    "    \n",
    "# Node\n",
    "def llm_with_tools(state: MessagesState):\n",
    "    return {\"messages\": [llm_tools.invoke(state[\"messages\"])]}\n",
    "\n",
    "# Build graph\n",
    "builder = StateGraph(MessagesState)\n",
    "builder.add_node(\"llm_with_tools\", llm_with_tools)\n",
    "builder.add_edge(START, \"llm_with_tools\")\n",
    "builder.add_edge(\"llm_with_tools\", END)\n",
    "graph = builder.compile()\n",
    "\n",
    "# View\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = graph.invoke({\"messages\": HumanMessage(content=\"How are you?\")})\n",
    "for m in messages['messages']:\n",
    "    m.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = graph.invoke({\"messages\": HumanMessage(content=\"Multiply 2 and 3\")})\n",
    "for m in messages['messages']:\n",
    "    m.pretty_print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
